{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import numpy as np\n",
    "import glob\n",
    "import os\n",
    "import pandas as pd\n",
    "\n",
    "import efficientnet.tfkeras as efn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_image_path = glob.glob('训练集\\\\train\\\\*\\\\*.png')\n",
    "labels = [int(i.split('\\\\')[2]) for i in train_image_path]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step1 使用 tf.data.Dataset.from_tensor_slices 进行加载\n",
    "image_ds = tf.data.Dataset.from_tensor_slices((train_image_path,labels))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "val_count = int(len(labels)*0.2)\n",
    "train_count = len(labels)-val_count"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_preprocess_image(path,label):\n",
    "    image = tf.io.read_file(path)\n",
    "    image = tf.image.decode_jpeg(image,channels=3)\n",
    "    image = tf.image.resize(image,[300,300])\n",
    "    image = tf.image.random_crop(image,[260,260,3])\n",
    "    image = tf.image.random_flip_left_right(image)\n",
    "    image = tf.image.random_flip_up_down(image)\n",
    "    image = tf.image.random_brightness(image,0.2)\n",
    "    \n",
    "    image = tf.cast(image,tf.float32)\n",
    "    image = image/255\n",
    "    label = tf.reshape(label,[1])\n",
    "    return image,label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_preprocess_image_test(path,label):\n",
    "    image = tf.io.read_file(path)\n",
    "    image = tf.image.decode_jpeg(image,channels=3)\n",
    "    image = tf.image.resize(image,[260,260])\n",
    "    image = tf.cast(image,tf.float32)\n",
    "    image = image/255\n",
    "    label = tf.reshape(label,[1])\n",
    "    return image,label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 选取验证集\n",
    "image_train_ds = image_ds.skip(val_count)\n",
    "image_val_ds = image_ds.take(val_count)\n",
    "\n",
    "# Step3 预处理 (预处理函数在下面)\n",
    "AUTOTUNE = tf.data.experimental.AUTOTUNE\n",
    "image_train_ds = image_train_ds.map(load_preprocess_image,num_parallel_calls=AUTOTUNE)\n",
    "image_val_ds = image_val_ds.map(load_preprocess_image_test,num_parallel_calls=AUTOTUNE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "BATCH_SIZE = 32\n",
    "# shuffle 打乱数据 batch设置 batch size repeat设置迭代次数(迭代2次) test数据集不需要\n",
    "image_train_ds = image_train_ds.repeat(2).shuffle(train_count).batch(BATCH_SIZE)\n",
    "image_train_ds = image_train_ds.prefetch(AUTOTUNE)#预取,GPU，CPU加速\n",
    "\n",
    "image_val_ds = image_val_ds.batch(BATCH_SIZE)\n",
    "image_val_ds = image_val_ds.prefetch(AUTOTUNE)#预取"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "covn_base = efn.EfficientNetB3(weights='imagenet',\n",
    "                               input_shape=(260,260,3),\n",
    "                               include_top=False,\n",
    "                               pooling='avg')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# covn_base.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = keras.Sequential()\n",
    "model.add(covn_base)\n",
    "model.add(keras.layers.Dropout(0.2))\n",
    "model.add(keras.layers.Dense(1024,activation='relu'))\n",
    "model.add(keras.layers.Dropout(0.2))\n",
    "model.add(keras.layers.Dense(512,activation='relu'))\n",
    "model.add(keras.layers.Dense(24,activation='sigmoid'))\n",
    "\n",
    "covn_base.trainable = False #设置权重参数不可训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/8\n",
      "46/46 [==============================] - 200s 4s/step - loss: 2.9939 - acc: 0.0768 - val_loss: 11.0184 - val_acc: 0.0000e+00\n",
      "Epoch 2/8\n",
      "46/46 [==============================] - 229s 5s/step - loss: 2.9694 - acc: 0.0808 - val_loss: 12.8389 - val_acc: 0.0000e+00\n",
      "Epoch 3/8\n",
      " 2/46 [>.............................] - ETA: 34s - loss: 2.9774 - acc: 0.0000e+00WARNING:tensorflow:Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches (in this case, 368 batches). You may need to use the repeat() function when building your dataset.\n",
      " 2/46 [>.............................] - 41s 21s/step - loss: 2.9774 - acc: 0.0000e+00 - val_loss: 12.8920 - val_acc: 0.0000e+00\n"
     ]
    }
   ],
   "source": [
    "#编译\n",
    "model.compile(optimizer=keras.optimizers.Adam(lr=0.001),\n",
    "             loss = 'sparse_categorical_crossentropy',\n",
    "             metrics=['acc'])\n",
    "\n",
    "history = model.fit(\n",
    "    image_train_ds,\n",
    "    steps_per_epoch=train_count//BATCH_SIZE,\n",
    "    epochs=8,\n",
    "    validation_data=image_val_ds,\n",
    "    validation_steps=val_count//BATCH_SIZE\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 9/16\n",
      "46/46 [==============================] - 231s 5s/step - loss: 2.8240 - acc: 0.1345 - val_loss: 12.0627 - val_acc: 0.0000e+00\n",
      "Epoch 10/16\n",
      "46/46 [==============================] - 215s 5s/step - loss: 2.6769 - acc: 0.1773 - val_loss: 11.6352 - val_acc: 0.0000e+00\n",
      "Epoch 11/16\n",
      " 2/46 [>.............................] - ETA: 35s - loss: 2.4950 - acc: 0.0682WARNING:tensorflow:Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches (in this case, 368 batches). You may need to use the repeat() function when building your dataset.\n",
      " 2/46 [>.............................] - 49s 24s/step - loss: 2.4950 - acc: 0.0682 - val_loss: 11.6590 - val_acc: 0.0000e+00\n"
     ]
    }
   ],
   "source": [
    "covn_base.trainable=True\n",
    "fine_tune_at = -3\n",
    "for layer in covn_base.layers[:fine_tune_at]:\n",
    "    layer.trainable = False\n",
    "\n",
    "model.compile(optimizer=keras.optimizers.Adam(0.0005),\n",
    "             loss = 'sparse_categorical_crossentropy',\n",
    "             metrics=['acc'])\n",
    "\n",
    "initial_epochs=8\n",
    "fine_tune_epochs=8\n",
    "total_epoch = initial_epochs+fine_tune_epochs\n",
    "\n",
    "history = model.fit(\n",
    "    image_train_ds,\n",
    "    steps_per_epoch=train_count//BATCH_SIZE,\n",
    "    epochs=total_epoch,\n",
    "    initial_epoch=initial_epochs,\n",
    "    validation_data=image_val_ds,\n",
    "    validation_steps=val_count//BATCH_SIZE\n",
    ")\n",
    "\n",
    "model.save('mstz_model_EfficientNetB52.h5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "#加载图片\n",
    "def load_preprocess_images(path):\n",
    "    image = tf.io.read_file(path)\n",
    "    image = tf.image.decode_jpeg(image,channels=3)\n",
    "    image = tf.image.resize(image,[260,260])\n",
    "    image = tf.cast(image,tf.float32)\n",
    "    image = image/255\n",
    "    image = tf.expand_dims(image,0)\n",
    "    return image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_image_path = glob.glob('测试集/*.png')\n",
    "\n",
    "test_image_path.sort(key=lambda x:int(x.split('\\\\')[-1].split('.')[0][1:]))\n",
    "\n",
    "images = [load_preprocess_images(image_path) for image_path in test_image_path]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ".........................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................."
     ]
    }
   ],
   "source": [
    "image_count = len(images)\n",
    "values = []\n",
    "result_dict = {}\n",
    "\n",
    "for i in range(image_count):\n",
    "    pred = model.predict(images[i])\n",
    "    values.append(np.argmax(pred))\n",
    "    print('.',end='')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "#写文件\n",
    "with open('result_efficientB2.csv','w',encoding='utf-8') as f:\n",
    "    f.write('image_id,category_id\\n')\n",
    "    [f.write('{0},{1}\\n'.format(key.split('\\\\')[1][:-4]+'.png', value)) for (key,value) in zip(test_image_path,values)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
